import math
from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
from torch import nn
from typing import Callable, List, Optional, Set, Tuple, Union
import inspect
import torchaudio.sox_effects as ta_sox
from torchaudio.compliance.kaldi import fbank
from torchaudio.functional import compute_kaldi_pitch
import soundfile as sf


###############################################################
###  ProsodyBERT Config
###############################################################

# # For LibriTTS
# config = ProsodyBertConfig(log_pitch_normalization=(5.169504057594512, 0.3896800908691966),
#                           energy_normalization=(1.042517152417171, 1.2270219395734532),
#                           frame_shift=20.0, target_sample_rate=16000, high_freq=500, n_bins=20)
# config.chunk_size_feed_forward = 0


class ProsodyBertConfig:
    """Plain container for ProsodyBERT model and feature-extractor settings.

    Every constructor argument is stored unchanged as an attribute of the
    same name.

    Model settings read by the modules in this file:
        input_dim: Width of the per-frame input features fed to the input
            projection. The default of 24 presumably matches the extractor
            layout for ``n_bins=20`` (pitch columns + pitch delta + energy +
            filterbank bins) -- confirm against torchaudio's output shapes.
        num_conv_pos_embeddings: Kernel size of the convolutional positional
            embedding.
        num_conv_pos_embedding_groups: Group count of that convolution.
        n_layers: Number of transformer blocks.
        n_heads: Attention heads per block; must divide ``dim``.
        dim: Hidden width of the encoder.
        bottleneck_dim: Width of the final bottleneck projection.
        hidden_dim: Inner width of the feed-forward sub-layer.
        dropout: Dropout probability in the feed-forward sub-layer.
        attention_dropout: Dropout probability on the attention weights.
        chunk_size_feed_forward: Chunk length along the sequence axis for the
            feed-forward sub-layer; ``0`` disables chunking.

    Settings stored but not referenced by the code visible in this file:
        output_size, span_max_length, span_pos_embed_dim,
        sinusoidal_pos_embds, activation, initializer_range, qa_dropout,
        seq_classif_dropout.

    Feature-extractor settings:
        log_pitch_normalization: ``(mean, scale)`` applied to the log pitch.
        energy_normalization: ``(mean, scale)`` applied to the energy.
        frame_shift: Frame shift forwarded to the pitch and fbank extractors.
        target_sample_rate: Sample rate the audio is resampled to.
        high_freq: Upper frequency limit forwarded to the fbank extractor.
        n_bins: Number of mel bins requested from the fbank extractor.
    """

    def __init__(
        self,
        input_dim=24,
        output_size=100,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        span_max_length=200,
        span_pos_embed_dim=32,
        sinusoidal_pos_embds=False,
        n_layers=6,
        n_heads=8,
        dim=512,
        bottleneck_dim=32,
        hidden_dim=2048,
        dropout=0.1,
        attention_dropout=0.1,
        activation="gelu",
        initializer_range=0.02,
        qa_dropout=0.1,
        seq_classif_dropout=0.2,
        log_pitch_normalization=(0.0, 1.0),
        energy_normalization=(0.0, 1.0),
        frame_shift=20.0,
        target_sample_rate=16000,
        high_freq=500,
        n_bins=20,
        chunk_size_feed_forward=0,
    ):
        # model configs
        self.input_dim = input_dim
        self.output_size = output_size
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.span_max_length = span_max_length
        self.span_pos_embed_dim = span_pos_embed_dim
        self.sinusoidal_pos_embds = sinusoidal_pos_embds
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = dim
        self.bottleneck_dim = bottleneck_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation = activation
        self.initializer_range = initializer_range
        self.qa_dropout = qa_dropout
        self.seq_classif_dropout = seq_classif_dropout
        # The feed-forward sub-layer reads this attribute; defining it here
        # means a default config can build the model without manual patching.
        self.chunk_size_feed_forward = chunk_size_feed_forward

        # feature extractor configs
        self.log_pitch_normalization = log_pitch_normalization
        self.energy_normalization = energy_normalization
        self.frame_shift = frame_shift
        self.target_sample_rate = target_sample_rate
        self.high_freq = high_freq
        self.n_bins = n_bins


###############################################################
###  ProsodyBERT Feature Extractor
###############################################################

class ProsodyBertFeatureExtractor(nn.Module):
    """Convert raw waveforms into padded, normalised ProsodyBERT input features.

    For each utterance the feature rows are the concatenation, in this order,
    of: the Kaldi pitch columns (column 1 log-transformed and normalised), a
    first-order delta of that column, the normalised fbank energy, and the
    mean/variance-normalised filterbank bins.

    Extracted features are memoised in ``self.hashes``. The cache is never
    evicted, so it grows with the number of distinct utterances seen.
    """

    def __init__(
        self,
        config
        ):
        super().__init__()
        self.frame_shift = config.frame_shift
        self.target_sample_rate = config.target_sample_rate
        self.high_freq = config.high_freq
        self.n_bins = config.n_bins
        self.log_pitch_mean, self.log_pitch_scale = config.log_pitch_normalization
        self.energy_mean, self.energy_scale = config.energy_normalization

        # Maps a content-derived key (see ``_cache_key``) to extracted features.
        self.hashes = {}

    @staticmethod
    def _cache_key(waveform, sample_rate):
        """Build a cache key that identifies the whole signal.

        The key covers the sample rate, shape, dtype and a digest of every
        sample, so two utterances that merely start with the same samples
        (for example leading silence) no longer share a cache entry.
        """
        import hashlib

        samples = waveform.detach().cpu().contiguous()
        digest = hashlib.sha256(samples.numpy().tobytes()).hexdigest()
        return (sample_rate, tuple(samples.shape), str(samples.dtype), digest)

    def cmvn(self, vec, variance_normalization=False):
        """Apply global mean (and optionally variance) normalization.

        Assumes one observation per row.

        Args:
            vec: Feature matrix of shape ``(num_observations, num_features)``.
            variance_normalization: When true, also divide each column by its
                standard deviation (plus a small epsilon).

        Returns:
            The normalized feature matrix, same shape as ``vec``.
        """
        eps = 2**-30

        # Broadcasting replaces the explicit row-wise tiling of the statistics.
        centered = vec - torch.mean(vec, 0, keepdim=True)
        if not variance_normalization:
            return centered

        stdev = torch.std(centered, 0, keepdim=True)
        return centered / (stdev + eps)

    def process_single_signal(
        self,
        waveform: torch.Tensor,
        sample_rate=16000,
        ):
        """Extract the feature matrix for one utterance.

        Args:
            waveform: Audio samples. The tensor is transposed before being
                handed to the sox effects, so it is presumably laid out as
                ``(time, channels)`` -- confirm with callers.
            sample_rate: Sample rate of ``waveform``.

        Returns:
            A detached 2-D tensor with one row per frame. The very same
            tensor object is returned again on later cache hits, so callers
            should not modify it in place.
        """
        cache_key = self._cache_key(waveform, sample_rate)
        cached = self.hashes.get(cache_key)
        if cached is not None:
            return cached

        target_sample_rate = self.target_sample_rate
        waveform = waveform.T

        # Volume normalisation, plus resampling when the rates differ.
        effects = [["gain", "-n"]]
        if target_sample_rate != sample_rate:
            effects.append(["rate", str(target_sample_rate)])

        waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, effects)

        # Kaldi pitch; ``[0]`` keeps the first entry along the leading axis.
        pitch_input = compute_kaldi_pitch(
            waveform, sample_rate=target_sample_rate, frame_shift=self.frame_shift
        )[0]

        # Filterbank with energy; column 0 is used as energy, the rest as bins.
        features = fbank(
            waveform,
            num_mel_bins=self.n_bins,
            sample_frequency=target_sample_rate,
            use_energy=True,
            high_freq=self.high_freq,
            frame_shift=self.frame_shift,
        )
        mel_input = features[:, 1:]
        energy_input = features[:, 0]

        # Normalise each stream.
        mel_input = self.cmvn(mel_input, variance_normalization=True)
        pitch_input[:, 1] = (torch.log(pitch_input[:, 1]) - self.log_pitch_mean) / self.log_pitch_scale
        energy_input = ((energy_input - self.energy_mean) / self.energy_scale).view(-1, 1)

        # A zero is appended so the delta keeps one value per frame; it is
        # created from ``pitch_input`` so device and dtype always match.
        pitch_delta = torch.diff(pitch_input[:, 1], append=pitch_input.new_zeros(1)).view(-1, 1)

        # The extractors may disagree on frame count; truncate to the shortest.
        common_len = min(pitch_input.shape[0], energy_input.shape[0], mel_input.shape[0])
        source = torch.cat(
            (
                pitch_input[:common_len],
                pitch_delta[:common_len],
                energy_input[:common_len],
                mel_input[:common_len],
            ),
            -1,
        )
        source = source.detach()

        self.hashes[cache_key] = source
        return source

    def forward(self, raw_speechs, sample_rate=16000, max_len=None):
        """Extract and zero-pad features for one utterance or a batch.

        Args:
            raw_speechs: A single waveform, or a list/tuple of waveforms.
            sample_rate: Sample rate shared by all given waveforms.
            max_len: Number of frames to pad or truncate to. Defaults to the
                longest utterance in the batch.

        Returns:
            Dict with ``"inputs_values"`` of shape ``(batch, frames, feat_dim)``
            and ``"attention_mask"`` of shape ``(batch, frames)`` holding 1 for
            real frames and 0 for padding.

        Raises:
            ValueError: If an empty batch is given.
        """
        if not isinstance(raw_speechs, (list, tuple)):
            raw_speechs = [raw_speechs]
        if len(raw_speechs) == 0:
            raise ValueError("raw_speechs must contain at least one waveform")

        sources = [self.process_single_signal(raw_speech, sample_rate) for raw_speech in raw_speechs]
        batch_size = len(sources)

        if max_len is None:
            batch_max_len = max(s.shape[0] for s in sources)
        else:
            batch_max_len = max_len

        feat_dim = sources[0].shape[-1]

        inputs_values = torch.zeros(batch_size, batch_max_len, feat_dim)
        attention_mask = torch.zeros(batch_size, batch_max_len, dtype=torch.long)

        for batch_i, source in enumerate(sources):
            source = source[:batch_max_len]
            len_source = source.shape[0]

            attention_mask[batch_i, :len_source] = 1
            inputs_values[batch_i, :len_source, :] = source

        return {"inputs_values": inputs_values, "attention_mask": attention_mask}



###############################################################
###  Modeling ProsodyBERT
###############################################################

def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
    """Run ``forward_fn`` over the inputs, optionally in chunks along one axis.

    Args:
        forward_fn: Callable taking exactly as many arguments as there are
            ``input_tensors``.
        chunk_size: Chunk length along ``chunk_dim``. Values that are not
            greater than zero mean "call ``forward_fn`` once on the full inputs".
        chunk_dim: Axis along which inputs are split and outputs re-joined.
        *input_tensors: Tensors passed positionally to ``forward_fn``.

    Returns:
        The output of ``forward_fn``; when chunking, the per-chunk outputs
        concatenated along ``chunk_dim``.

    Raises:
        ValueError: If the argument count does not match ``forward_fn``, the
            inputs differ in length along ``chunk_dim``, or that length is not
            a multiple of ``chunk_size``.
    """
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    expected_arg_count = len(inspect.signature(forward_fn).parameters)
    given_arg_count = len(input_tensors)
    if expected_arg_count != given_arg_count:
        raise ValueError(
            f"forward_chunk_fn expects {expected_arg_count} arguments, but only {given_arg_count} input "
            "tensors are given"
        )

    # No chunking requested: a single plain call.
    if not chunk_size > 0:
        return forward_fn(*input_tensors)

    reference_len = input_tensors[0].shape[chunk_dim]
    for tensor in input_tensors:
        current_len = tensor.shape[chunk_dim]
        if current_len != reference_len:
            raise ValueError(
                f"All input tenors have to be of the same shape: {reference_len}, "
                f"found shape {current_len}"
            )

    if reference_len % chunk_size != 0:
        raise ValueError(
            f"The dimension to be chunked {reference_len} has to be a multiple of the chunk "
            f"size {chunk_size}"
        )

    num_chunks = reference_len // chunk_size

    # Split every input, feed aligned chunks to forward_fn, then stitch back.
    per_tensor_chunks = [tensor.chunk(num_chunks, dim=chunk_dim) for tensor in input_tensors]
    chunk_outputs = [forward_fn(*chunk_group) for chunk_group in zip(*per_tensor_chunks)]
    return torch.cat(chunk_outputs, dim=chunk_dim)


def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    """Work out which heads still need pruning and which flat indices survive.

    Args:
        heads: Head indices requested for pruning.
        n_heads: Number of heads currently present in the layer.
        head_size: Width of one head.
        already_pruned_heads: Heads removed by earlier pruning calls.

    Returns:
        A tuple of (heads still to prune, indices to keep in the flattened
        ``n_heads * head_size`` dimension).
    """
    keep = torch.ones(n_heads, head_size)
    heads = set(heads) - already_pruned_heads
    for original_head in heads:
        # Heads pruned earlier shift the position of every later head down.
        shift = len([h for h in already_pruned_heads if h < original_head])
        keep[original_head - shift] = 0
    flat_keep = keep.view(-1).contiguous().eq(1)
    index: torch.LongTensor = torch.arange(len(flat_keep))[flat_keep].long()
    return heads, index



def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """Build a new linear layer that keeps only the entries in ``index``.

    Args:
        layer: The layer to prune; it is left untouched.
        index: Indices to keep along ``dim`` of the weight matrix.
        dim: Weight axis to prune. Along axis 1 the bias is copied whole,
            otherwise it is indexed with ``index`` as well.

    Returns:
        A fresh ``nn.Linear`` on the same device holding the selected weights
        (and bias, if the source layer has one), with gradients enabled.
    """
    device = layer.weight.device
    has_bias = layer.bias is not None
    index = index.to(device)

    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    if has_bias:
        source_bias = layer.bias if dim == 1 else layer.bias[index]
        kept_bias = source_bias.clone().detach()

    shape = list(layer.weight.size())
    shape[dim] = len(index)
    out_features, in_features = shape

    pruned = nn.Linear(in_features, out_features, bias=has_bias).to(device)
    # Copy outside of autograd so the new parameters remain trainable leaves.
    with torch.no_grad():
        pruned.weight.copy_(kept_weight.contiguous())
        if has_bias:
            pruned.bias.copy_(kept_bias.contiguous())
    return pruned


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head attention with optional head pruning."""

    def __init__(self, config):
        super().__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads: Set[int] = set()

    def prune_heads(self, heads: List[int]):
        """Remove the given heads from the projection layers.

        Heads already recorded in ``self.pruned_heads`` are skipped. After the
        call ``self.n_heads`` and ``self.dim`` describe the remaining heads.
        """
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
        # Prune linear layers: q/k/v lose output rows, out_lin loses input columns.
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length); key positions where the mask
                equals 0 are excluded from attention. ``None`` disables
                masking, i.e. every position is attended to.
            head_mask: optional tensor multiplied into the attention weights;
                it must broadcast against (bs, n_heads, q_length, k_length).
            output_attentions: whether to also return the attention weights.

        Returns:
            ``(context,)`` or, if ``output_attentions`` is set,
            ``(context, weights)`` where
            context: torch.tensor(bs, seq_length, dim) contextualized layer
            weights: torch.tensor(bs, n_heads, seq_length, seq_length)
        """
        bs = query.size(0)
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)

        if mask is not None:
            # Excluded keys get the most negative finite score so that their
            # softmax weight underflows to zero.
            excluded = (mask == 0).view(bs, 1, 1, k_length).expand_as(scores)
            scores = scores.masked_fill(excluded, torch.finfo(scores.dtype).min)

        weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        return (context,)


class FFN(nn.Module):
    """Position-wise feed-forward sub-layer: linear -> GELU -> linear -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        # Fall back to 0 (no chunking) for configs that do not define the
        # attribute, instead of failing with an AttributeError.
        self.chunk_size_feed_forward = getattr(config, "chunk_size_feed_forward", 0)
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        # NOTE: GELU is used regardless of any ``activation`` setting on the config.
        self.activation = nn.GELU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network, chunked along axis 1 if configured."""
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward network on one (possibly chunked) input."""
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and feed-forward, each followed by a
    residual connection and layer normalisation."""

    def __init__(self, config):
        super().__init__()

        assert config.dim % config.n_heads == 0

        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)

        Returns:
            ``(ffn_output,)``, or ``(sa_weights, ffn_output)`` when
            ``output_attentions`` is set, where
            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) the attention weights
            ffn_output: torch.tensor(bs, seq_length, dim) the output of the block
        """
        # --- self-attention sub-layer ---
        attn_outputs = self.attention(
            query=x,
            key=x,
            value=x,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
            sa_output, sa_weights = attn_outputs
        else:
            # The attention always hands back a tuple; take its only entry.
            assert type(attn_outputs) == tuple
            sa_output = attn_outputs[0]
        attended = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # --- feed-forward sub-layer ---
        projected = self.ffn(attended)  # (bs, seq_length, dim)
        ffn_output: torch.Tensor = self.output_layer_norm(projected + attended)  # (bs, seq_length, dim)

        if output_attentions:
            return (sa_weights, ffn_output)
        return (ffn_output,)


class Transformer(nn.Module):
    """A stack of ``config.n_layers`` transformer blocks applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers
        self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: Optional[bool] = None,
    ):  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.
            head_mask: Per-layer head masks, indexable by layer number, or
                ``None`` to apply no head mask in any layer.
            output_attentions: Collect the attention weights of every layer.
            output_hidden_states: Collect the input of every layer plus the
                final output.
            return_dict: Return a dict instead of a tuple when truthy.

        Returns:
            When ``return_dict`` is truthy, a dict with keys
            ``"last_hidden_state"``, ``"hidden_states"`` and ``"attentions"``
            (the latter two are ``None`` if not requested). Otherwise a tuple
            holding, in that order, only the entries that are not ``None``.

            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers + 1: the input of each layer followed by the final output.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            # ``head_mask`` defaults to None, so it must not be indexed blindly.
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                x=hidden_state,
                attn_mask=attn_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
            # The block puts its hidden state last, after optional attentions.
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                all_attentions = all_attentions + (layer_outputs[0],)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)

        return {
            "last_hidden_state": hidden_state,
            "hidden_states": all_hidden_states,
            "attentions": all_attentions,
        }


class ProsodybertSamePadLayer(nn.Module):
    """Trim the trailing step off the last axis when the kernel size is even.

    Used after the positional convolution, which pads by half the kernel size
    on both sides.
    """

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        is_even_kernel = num_conv_pos_embeddings % 2 == 0
        self.num_pad_remove = 1 if is_even_kernel else 0

    def forward(self, hidden_states):
        # Nothing to trim for odd kernel sizes.
        if not self.num_pad_remove > 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Prosodybert
class ProsodybertPositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding computed from the hidden states.

    A weight-normalised grouped 1-D convolution runs along the sequence axis,
    followed by trimming of the padding surplus and a GELU.
    """

    def __init__(self, config):
        super().__init__()

        kernel_size = config.num_conv_pos_embeddings
        grouped_conv = nn.Conv1d(
            config.dim,
            config.dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
        )
        self.conv = nn.utils.weight_norm(grouped_conv, name="weight", dim=2)

        self.padding = ProsodybertSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        # Conv1d convolves over the last axis, so swap axes 1 and 2 going in
        # and swap them back on the way out.
        channels_first = hidden_states.transpose(1, 2)
        embedded = self.activation(self.padding(self.conv(channels_first)))
        return embedded.transpose(1, 2)



class ProsodyBertModel(nn.Module):
    """ProsodyBERT encoder.

    Pipeline: input projection -> convolutional positional embedding ->
    transformer stack -> bottleneck projection of the last hidden state.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.input_projector = nn.Linear(config.input_dim, config.dim)
        self.pos_embeddings = ProsodybertPositionalConvEmbedding(config)  # Embeddings
        self.transformer = Transformer(config)  # Encoder

        # mask embedding
        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.dim).uniform_())

        # bottleneck layer
        self.bottleneck = nn.Linear(config.dim, config.bottleneck_dim)

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """Prune attention heads; maps a layer index to the heads to remove there."""
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
        """Expand a head mask to 5 dimensions, with one entry per layer first.

        A 1-D mask ``(n_heads,)`` is shared by all layers; a 2-D mask
        ``(n_layers, n_heads)`` supplies one row per layer. Indexing the
        result by layer yields a tensor that broadcasts against attention
        weights of shape ``(bs, n_heads, q_length, k_length)``.

        Raises:
            ValueError: If the mask cannot be brought to 5 dimensions.
        """
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        if head_mask.dim() != 5:
            raise ValueError(f"head_mask.dim != 5, instead {head_mask.dim()}")
        # Match the model's parameter dtype so the multiplication is type-safe.
        return head_mask.to(dtype=self.input_projector.weight.dtype)

    def get_head_mask(
        self, head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> torch.Tensor:
        """Return per-layer head masks.

        Yields a list of ``None`` (one per layer) when no mask is given,
        otherwise the mask expanded by ``_convert_head_mask_to_5d`` (with one
        extra trailing axis if ``is_attention_chunked`` is true).
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def forward(
        self,
        inputs_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode a batch of prosody features.

        Args:
            inputs_values: Features of shape ``(bs, seq_length, input_dim)``.
            attention_mask: ``(bs, seq_length)`` mask, 0 for positions to
                ignore; defaults to all ones.
            mask_time_indices: Optional index (e.g. a boolean
                ``(bs, seq_length)`` tensor) of frames whose embedding is
                replaced by the learned mask embedding.
            head_mask: Optional 1-D or 2-D head mask.
            output_attentions: Also return per-layer attention weights.
            output_hidden_states: Also return per-layer hidden states.
            return_dict: Accepted but ignored; a dict is always returned.

        Returns:
            Dict with ``"last_hidden_state"`` (after the bottleneck, shape
            ``(bs, seq_length, bottleneck_dim)``), ``"hidden_states"`` and
            ``"attentions"`` (``None`` unless requested).

        Raises:
            ValueError: If ``inputs_values`` is not provided.
        """
        if inputs_values is None:
            raise ValueError("inputs_values must be provided")

        input_shape = inputs_values.size()[:-1]
        device = inputs_values.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

        # project, then add the positional embedding
        inputs_values = self.input_projector(inputs_values)
        inputs_values = inputs_values + self.pos_embeddings(inputs_values)  # (bs, seq_length, dim)

        # mask by time indices
        if mask_time_indices is not None:
            inputs_values[mask_time_indices] = self.masked_spec_embed.to(inputs_values.dtype)

        outputs = self.transformer(
            x=inputs_values,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        outputs["last_hidden_state"] = self.bottleneck(outputs["last_hidden_state"])

        return outputs